using Plots
using SparseArrays
using LinearAlgebra
using Interpolations
# MATLAB-style helper: dense n×n identity matrix with Float64 entries.
function eye(n)
    return Matrix(1.0I, n, n)
end
# Select the Plotly backend of Plots.jl at load time (some 3D plotting functions below call pyplot() instead).
plotly()

# Return the per-layer lateral shifts as an indexable collection of 2-component shifts.
# A matrix is read row by row (row l is the shift of layer l).
_layer_shifts(shift::AbstractMatrix) = [shift[l, :] for l in axes(shift, 1)]
# Any other collection is assumed to already hold one 2-element shift (tuple or vector) per layer.
_layer_shifts(shift) = shift

# Calculate the continuum-model energy spectrum of twisted multilayer graphene with a single
# twist angle.
# Nlayer: number of layers
# config[1..Nlayer]: 0 or 1 per layer; layers marked 1 are twisted by theta with respect to
#     layers marked 0 (config-0 layers are rotated by +theta/2, config-1 layers by -theta/2).
#     Only adjacent layers with different config values are coupled.
# theta: twist angle in radians
# w[1..2]: interlayer hopping terms. w[1]: diagonal (sublattice-preserving) term, w[2]: off-diagonal term
# shift: lateral shift of each layer, either as a collection of Nlayer 2-element tuples/vectors
#     or as an Nlayer×2 matrix whose row l is the shift of layer l (presumably in units of a).
#     Note that the two subsets of layers with different twist angle have independent shifts.
# V[1..Nlayer]: electrical potential of each layer (eV), added to both sublattices
# kpoints: collection of (kx, ky) 2-tuples, in units of Ktheta
# maxM: cut-off of the momentum lattice; lattice points with |G| <= sqrt(3)*maxM (units of Ktheta) are kept
# hbarVf: Dirac velocity, in eV·a
# Bin: optional in-plane magnetic field [Bx, By]; layer l is boosted by (l-(Nlayer+1)/2)*kbpar
#     (NOTE(review): presumably in tesla — confirm)
# anisotropy: scale factor applied to the H1 hopping matrix only (1 = isotropic)
# periodic: also couple layer Nlayer back to layer 1
# nostate: skip storing eigenvectors and layer projections
# Returns a NamedTuple with
#   E: eigenvalues, (2M*Nlayer) × length(kpoints), ascending within each column
#   V: eigenvectors, (2M*Nlayer) × (2M*Nlayer) × length(kpoints)   (omitted if nostate)
#   L: weight of each eigenstate on each layer, bands × Nlayer × kpoints (omitted if nostate)
#   H: k-independent part of the Hamiltonian (interlayer hopping + layer potentials)
#   h: [h1; h2], the momentum lattices of the two layer subsets (x, y, i, j per row)
function txg_continuum_bands(Nlayer, config, theta,
     w, shift, V, kpoints;
     maxM=4.1, hbarVf = 2.7*sqrt(3)/2, Bin=[0,0], anisotropy=1, periodic=false, nostate=false)
    # All in the unit of a = 0.246 nm; eV
    all(c -> c == 0 || c == 1, (config[l] for l in 1:Nlayer)) ||
        throw(ArgumentError("config entries must be 0 or 1, got $(config)"))

    @show Ktheta = 8pi/3*sin(theta/2) # Momentum displacement between two K-points due to twisting

    # Interlayer hopping matrices; `anisotropy` (strain) rescales H1 only, 1 means isotropic.
    r3 = (Complex(-1))^(2/3)
    w_tw = w[1]
    w_tw_off = w[2]
    H1 = [w_tw w_tw_off; w_tw_off w_tw] * anisotropy
    H2 = [w_tw*r3^2 w_tw_off; w_tw_off*r3 w_tw*r3^2]
    H3 = [w_tw*r3 w_tw_off; w_tw_off*r3^2 w_tw*r3]

    # Phase factor of each layer's lateral shift, one vector per hopping direction.
    shifts = _layer_shifts(shift)
    ph1 = [exp(im*4pi/3*[0,-1]'*collect(D)) for D in shifts]
    ph2 = [exp(im*4pi/3*[sqrt(3)/2, 0.5]'*collect(D)) for D in shifts]
    ph3 = [exp(im*4pi/3*[-sqrt(3)/2, 0.5]'*collect(D)) for D in shifts]

    # In-plane field momentum boost, in unit of 1/a
    ca = 0.33e-9 * 0.246e-9  # m^2
    @show kbpar = 2pi * ca / 4.136e-15 * [Bin[2], -Bin[1]]

    # Hexagonal momentum lattice around the origin, bounded by maxM.
    # Columns: kx, ky (in units of Ktheta), i, j (integer lattice indices).
    # h1 and h2 are the lattices of the config-0 and config-1 layers; 0,0 is the
    # midpoint between K_s and K'_s, i.e. M_s.
    g1 = [sqrt(3), 0, 1, 0]
    g2 = [sqrt(3)/2, 3/2, 0, 1]
    nrange = -2*floor(maxM)-2:2maxM+2
    hlist = [(g1.*i)+(g2*j) for i=nrange for j=nrange]
    filter!(p -> sum(abs2, p[1:2]) <= 3*maxM^2, hlist) # Remove points farther than maxM
    h1 = permutedims(reduce(hcat, hlist))
    M = size(h1, 1)
    h2 = h1 .+ [0 -1 0 0]

    # Cn[i,j] == 1 means lattice point i of the config-0 layers couples to lattice point j
    # of the config-1 layers along direction n = 1, 2, 3.
    C1 = Matrix{Float64}(I, M, M)
    C2 = zeros(M, M)
    C3 = zeros(M, M)
    for i in 1:M, j in 1:M
        if h1[i,3] == h2[j,3] && h1[i,4] == h2[j,4]-1
            C2[i,j] = 1 # H2 connection
        end
        if h1[i,3] == h2[j,3]+1 && h1[i,4] == h2[j,4]-1
            C3[i,j] = 1 # H3 connection
        end
    end

    # Hamiltonian rows/columns of layer l; within a layer, index = 2*(momentum-1) + sublattice.
    block(l) = (1:2M) .+ (l-1)*2M
    # Hopping block from config-0 layer a (rows) to config-1 layer b (columns).
    coupling(a, b) = kron(C1, H1*ph1[a]/ph1[b]) + kron(C2, H2*ph2[a]/ph2[b]) + kron(C3, H3*ph3[a]/ph3[b])

    N = 2M*Nlayer
    # Upper-triangular interlayer hopping, k-independent.
    Hinter = zeros(ComplexF64, N, N)
    for l in 1:Nlayer-1
        if config[l] == 0 && config[l+1] == 1
            Hinter[block(l), block(l+1)] = coupling(l, l+1)
        elseif config[l] == 1 && config[l+1] == 0
            # Adjoint swaps the roles so the block still couples the 0-layer to the 1-layer.
            Hinter[block(l), block(l+1)] = coupling(l, l+1)'
        end
    end
    if periodic
        # Connect Nth layer back to the first layer
        if config[Nlayer] == 0 && config[1] == 1
            Hinter[block(Nlayer), block(1)] = coupling(Nlayer, 1)
        elseif config[Nlayer] == 1 && config[1] == 0
            Hinter[block(Nlayer), block(1)] = coupling(Nlayer, 1)'
        end
    end

    # Hermitianize and add the on-site layer potential.
    H0 = Hinter + Hinter' + kron(diagm(0 => V), Matrix{Float64}(I, 2M, 2M))

    println("Size of Hamiltonian: $(2M*Nlayer)")

    nk = length(kpoints)
    E = zeros(N, nk) # Eigenvalues
    states = nostate ? nothing : zeros(ComplexF64, N, N, nk) # Eigenstates
    L = zeros(N, Nlayer, nk) # Layer projection

    Threads.@threads for ik in 1:nk
        (kx, ky) = kpoints[ik]

        println("Kpoint #$(ik)")
        H = copy(H0)
        # Intralayer Dirac Hamiltonian, rotated by +-theta/2 depending on the layer subset.
        for l in 1:Nlayer
            (th, h) = config[l] == 0 ? (theta/2, h1) : (-theta/2, h2)
            for m in 1:M
                ind = 1 + 2*((m-1) + M*(l-1)) # sublattice-1 row of momentum m in layer l
                (qx, qy) = [kx, ky] * Ktheta .- h[m,1:2] * Ktheta .+ (l-(Nlayer+1)/2) * kbpar
                H[ind, ind+1] = -hbarVf * sqrt(qx^2+qy^2) * exp(im*(atan(qy, qx) - th))
                H[ind+1, ind] = conj(H[ind, ind+1])
            end
        end

        # Hermitian solver: real eigenvalues, returned in ascending order.
        F = eigen(Hermitian(H))
        E[:, ik] = F.values
        if !nostate
            states[:, :, ik] = F.vectors # internal DOF × eigenvalue index
            for l in 1:Nlayer
                L[:, l, ik] = vec(sum(abs2, view(F.vectors, block(l), :); dims=1))
            end
        end
    end

    hh = vcat(h1, h2)
    if nostate
        return (E=E, H=H0, h=hh)
    else
        return (E=E, V=states, L=L, H=H0, h=hh)
    end
end

# Sample a piecewise-linear k-path through the points of `klist` (2-tuples).
# Each segment contributes `res` evenly spaced points including both endpoints, so the
# vertices shared by consecutive segments appear twice.
# Returns (kp, scale, keys):
#   kp:    Vector of (kx, ky) tuples along the path
#   scale: cumulative path length at each point of kp (for use as a plot x-axis)
#   keys:  cumulative path length at each vertex of klist, starting at 0.0
function gen_k(klist, res)
    kp = Tuple{Float64,Float64}[]
    scale = Float64[]
    keys = [0.0]
    sc = 0.0
    for i in 1:length(klist)-1
        kx = collect(range(klist[i][1], stop=klist[i+1][1], length=res))
        ky = collect(range(klist[i][2], stop=klist[i+1][2], length=res))
        # distance of each sample from the start of this segment
        s = sqrt.((kx .- kx[1]).^2 .+ (ky .- ky[1]).^2)
        append!(kp, zip(kx, ky))
        append!(scale, s .+ sc)
        sc += s[end]
        push!(keys, sc)
    end
    return (kp, scale, keys)
end

# Plot the 20 bands around charge neutrality along the path K' -> K -> G -> M -> K' for
# both valleys (the second valley uses the path point-reflected through Gamma and -Bin).
# Layer potentials come from the screening model A * V = B with quantum-capacitance ratio
# g = Cg/Cq; the outer layers here have diagonal 1+g.
# save: if non-empty, write "k\tE\tE'" lines, one block per band separated by blank lines.
# Returns the Plots figure.
function plotTXGBands_screened(Nlayer, config, theta, res, shift; title="", Efield=0, Cq=1000, Cg=0.026, save="", Bin=[0,0], w=[0.08,0.1], anisotropy=1, periodic=false)
    K = (0, 0)
    M = (0, -0.5)
    Kp = (0, -1)
    Gamma = (sqrt(3)/2, -0.5)
    kpoints = first(gen_k([Kp, K, Gamma, M, Kp], res))
    # The mirrored path has the same segment lengths; its scale/ticks are used for the x-axis.
    (kpointsKp, scale, ticks) = gen_k([2 .* Gamma .- Kp, 2 .* Gamma .- K, Gamma, 2 .* Gamma .- M, 2 .* Gamma .- Kp], res)

    g = Cg/Cq # Quantum capacitance ratio
    @show V0 = 8.85e-12 * Efield / Cg
    B = zeros(Nlayer, 1); B[1] = V0*g; B[end] = -V0*g
    A = diagm(0 => [1+g; ones(Nlayer-2)*(1+2g); 1+g], 1 => ones(Nlayer-1)*(-g), -1 => ones(Nlayer-1)*(-g))
    @show V = (A\B)[:]

    D = txg_continuum_bands(Nlayer, config, theta, w, shift, V, kpoints; Bin=Bin, anisotropy=anisotropy, periodic=periodic)
    DKp = txg_continuum_bands(Nlayer, config, theta, w, shift, V, kpointsKp; Bin=-Bin, anisotropy=anisotropy, periodic=periodic)
    Nbands = size(D.E, 1)
    Dirac = Int(Nbands/2)
    E = D.E[Dirac-9:Dirac+10, :] # band index, momentum index
    EKp = DKp.E[Dirac-9:Dirac+10, :] # band index, momentum index

    h = plot(scale, E', ylims=(-0.15,0.15), legend=false, xticks=(ticks, ["K'","K","G","M","K'"]))
    plot!(h, scale, EKp')
    xlabel!(h, "Momentum")
    ylabel!(h, "E (eV)")
    title!(h, title)

    if save != ""
        open(save, "w") do fp
            for b in axes(E, 1)
                for k in eachindex(scale)
                    println(fp, "$(scale[k])\t$(E[b, k])\t$(EKp[b, k])")
                end
                println(fp)
            end
        end
    end

    return h
end
#
# Plot bands Dirac-1:Dirac+2 as 3D surfaces, one subplot per layer, coloured by the
# weight of each state on that layer. The k-grid is the (res+1)×(res+1) parallelogram
# spanned by k1 and k2, shown together with its two 120° rotations about gamma.
# Layer potentials come from the screening model A * V = B (outer-layer diagonal 1+2g here).
# save: if non-empty, write "kx\tky\tE" for bands Dirac-4:Dirac+5 on the unrotated patch.
# Returns the Plots figure (the pyplot backend is selected as a side effect).
function plotTXGBands3D_screened(Nlayer, config, theta, res, shift; title="", Efield=0, Cq=1000, Cg=0.026, save="", Bin=[0,0], w=[0.08,0.1], anisotropy=1)
    gamma = (sqrt(3)/2, -0.5)
    k1 = (sqrt(3)/2, 0.5)
    k2 = (0, -1.0)
    # Flattened grid with j varying fastest, matching the reshape of D.E below.
    kpoints = [k1.*i .+ k2.*j for i=range(0,stop=1,length=res+1) for j=range(0,stop=1,length=res+1)]

    g = Cg/Cq # Quantum capacitance ratio
    @show V0 = 8.85e-12 * Efield / Cg
    B = zeros(Nlayer, 1); B[1] = V0*g; B[end] = -V0*g
    A = diagm(0 => ones(Nlayer)*(1+2g), 1 => ones(Nlayer-1)*(-g), -1 => ones(Nlayer-1)*(-g))
    @show V = (A\B)[:]

    D = txg_continuum_bands(Nlayer, config, theta, w, shift, V, kpoints; Bin=Bin, anisotropy=anisotropy)
    Nbands = size(D.E, 1)
    Dirac = Int(Nbands/2)
    @show size(D.E)
    @show size(D.L)
    E = reshape(D.E, :, res+1, res+1) # E[i,j,k] j,k labels momentum, i labels band
    L = reshape(D.L, :, Nlayer, res+1, res+1) # band index, layer, and momentum (last two indices)

    # Rotation by -120° (equivalently 240°)
    R3 = [-0.5 sqrt(3)/2; -sqrt(3)/2 -0.5]

    # Extend the patch by rotation about gamma
    kall = map(x -> x .- gamma, kpoints)
    append!(kall, map(x -> (x[1], x[2]), map(x -> R3*collect(x .- gamma), kpoints)))
    append!(kall, map(x -> (x[1], x[2]), map(x -> R3*R3*collect(x .- gamma), kpoints)))

    kx = reshape(map(x -> x[1], kall), res+1, (res+1) * 3)
    ky = reshape(map(x -> x[2], kall), res+1, (res+1) * 3)

    pyplot()
    h = plot(legend=false, layout=Nlayer, camera=(45,5), size=(1200,1600), clims=(0.0,1.0))
    for layer = 1:Nlayer
        for i = -1:2
            plot!(h, kx[:,1:res+1], ky[:,1:res+1], E[Dirac+i,:,:], linetype=:surface, xlabel="kx", ylabel="ky", zlabel="E (eV)", fill_z=L[Dirac+i,layer,:,:], clims=(0.0,0.001), subplot=layer, title="Layer $layer")
            plot!(h, kx[:,res+2:2res+2], ky[:,res+2:2res+2], E[Dirac+i,:,:], linetype=:surface, fill_z=L[Dirac+i,layer,:,:], subplot=layer)
            plot!(h, kx[:,2res+3:3res+3], ky[:,2res+3:3res+3], E[Dirac+i,:,:], linetype=:surface, fill_z=L[Dirac+i,layer,:,:], subplot=layer)
        end
    end

    if save != ""
        open(save, "w") do fp
            for k = Dirac-4:Dirac+5
                for i = 1:res+1
                    for j = 1:res+1
                        println(fp, "$(kx[i,j])\t$(ky[i,j])\t$(E[k,i,j])")
                    end
                    println(fp, "")
                end
                println(fp, "")
            end
        end
    end

    return h
end




# MATLAB-style 1-D interpolation of the samples (xpt, ypt) at x.
# xpt and ypt are 1-D arrays; Interpolations.jl requires xpt sorted in increasing order
# (NOTE(review): the "cubic" spline constructor may additionally require xpt to be a range —
# confirm for the installed Interpolations version). x can be any shape; the result has x's shape.
# method: "linear" or "cubic"; anything else throws ArgumentError.
# extrapvalue: behaviour outside the sample range —
#     nothing -> NaN, :linear -> linear extrapolation, a number -> that constant (as Float64).
function interp1(xpt, ypt, x; method="linear", extrapvalue=nothing)
    if isnothing(extrapvalue)
        bc = NaN
    elseif extrapvalue === :linear
        bc = Line()
    elseif extrapvalue isa Number
        bc = Float64(extrapvalue)
    else
        throw(ArgumentError("extrapvalue can be nothing, :linear, or a number"))
    end
    if method == "linear"
        intf = LinearInterpolation(xpt, ypt, extrapolation_bc = bc)
    elseif method == "cubic"
        intf = CubicSplineInterpolation(xpt, ypt, extrapolation_bc = bc)
    else
        throw(ArgumentError("method must be \"linear\" or \"cubic\", got $(repr(method))"))
    end
    return intf.(x)
end

# z-component of the cross product of two 2-D vectors given as tuples
# (twice the signed area of the triangle they span).
function cross2(p1::Tuple{Real,Real}, p2::Tuple{Real,Real})
    (ax, ay) = p1
    (bx, by) = p2
    return ax * by - ay * bx
end

# Distribute the k-space area of one triangle over the energy bins `Ebin` (sorted bin edges),
# treating the energy as linear across the triangle.
# pts: the three vertices as (kx, ky) tuples; Es: the energies at those vertices.
# Energies outside [Ebin[1], Ebin[end]] are clamped onto the range edges, so their area is
# assigned to the outermost bins. A triangle that is flat after clamping contributes nothing.
# Returns a vector of length(Ebin)-1 with the area falling in each bin; its sum equals the
# triangle area (checked by assertion).
function filltriangle(Ebin, pts, Es)
    # z-component of the 2-D cross product
    crossz(a, b) = a[1]*b[2] - a[2]*b[1]

    idx = sortperm(Es)
    pts = pts[idx]
    Es = Es[idx]
    Emin = clamp(Es[1], Ebin[1], Ebin[end])
    Emid = clamp(Es[2], Ebin[1], Ebin[end])
    Emax = clamp(Es[3], Ebin[1], Ebin[end])
    d = zeros(length(Ebin)-1)

    # Flat triangle (including one entirely clamped to an edge): nothing to spread.
    # Checked before the interpolation below, which divides by Emax - Emin.
    Emax == Emin && return d

    # Intersection of plane z=Emid with the edge pts[1]-pts[3] (energy is linear along it)
    t = (Emid - Emin) / (Emax - Emin)
    pmid = (pts[1][1] + t*(pts[3][1] - pts[1][1]),
            pts[1][2] + t*(pts[3][2] - pts[1][2]))
    area1 = abs(crossz(pts[2] .- pmid, pts[1] .- pmid)) / 2 # area of pts[1]-pts[2]-pmid, from energy Emin to Emid
    area2 = abs(crossz(pts[2] .- pmid, pts[3] .- pmid)) / 2 # area of pts[3]-pts[2]-pmid, from energy Emid to Emax

    @assert abs(area1+area2-abs(crossz(pts[2] .- pts[1], pts[3] .- pts[1]))/2) < 1e-11

    for i in 1:length(Ebin)-1
        if Emax <= Ebin[i] || Emin >= Ebin[i+1]
            # Energy out of bin
            continue
        end

        Emidc = clamp(Emid, Ebin[i], Ebin[i+1])

        if Emid > Ebin[i] && Emid != Emin
            # Area of the lower sub-triangle between Eminc and Emidc (grows quadratically in E)
            Eminc = max(Emin, Ebin[i])
            d[i] += ((Emidc-Emin)^2 - (Eminc-Emin)^2) / (Emid - Emin)^2 * area1
        end
        if Emid < Ebin[i+1] && Emid != Emax
            # Area of the upper sub-triangle between Emidc and Emaxc
            Emaxc = min(Ebin[i+1], Emax)
            d[i] += ((Emax-Emidc)^2 - (Emax - Emaxc)^2) / (Emax - Emid)^2 * area2
        end
    end

    @assert abs(sum(d)-area1-area2)/(area1+area2) < 1e-11 "Error: sum of triangles not conserved"
    return d
end

# Density of states of the four bands Dirac-1:Dirac+2 (Dirac = size(E,1)/2) on a k-grid.
# Each grid plaquette is split into two triangles and filltriangle spreads each triangle's
# k-space area over energy bins, treating the energy as linear on the triangle.
# kx, ky: (res+1)×(res+1) k-point grids in units of Ktheta (e.g. as returned by
#     calcTXGBandsFullBZ_screened)
# E: bands × (res+1) × (res+1) energies on that grid (eV)
# bin: energy bin width (eV); theta: twist angle (radians), used for the area normalisation
# save: if non-empty, write "energy\tdos" lines to this file
# Returns (ecenter, dos) and displays a plot.
# NOTE(review): the last bin edge can fall below maximum(E) when the energy span is not a
# multiple of `bin`; filltriangle clamps those energies into the last bin.
function calcDOS(kx, ky, E, res, bin, theta; save="")
    Nbands = size(E, 1)
    Dirac = Int(Nbands/2)
    bindex = Dirac-1:Dirac+2
    E = E[bindex, :, :]
    kpoints = reshape(map((x, y) -> (x, y), kx[:], ky[:]), res+1, res+1)

    ebin = range(minimum(E), maximum(E), step=bin)
    ecenter = (ebin[1:end-1] .+ ebin[2:end]) ./ 2
    dos = zeros(size(ecenter))

    for i in 1:res, j in 1:res, b in eachindex(bindex)
        # Break each plaquette into two triangles
        dos .+= filltriangle(ebin,
            [kpoints[i,j], kpoints[i,j+1], kpoints[i+1,j]],
            [E[b,i,j], E[b,i,j+1], E[b,i+1,j]])
        dos .+= filltriangle(ebin,
            [kpoints[i+1,j+1], kpoints[i,j+1], kpoints[i+1,j]],
            [E[b,i+1,j+1], E[b,i,j+1], E[b,i+1,j]])
    end

    # Convert k-space area (units of Ktheta^2) to states per nm^2 per eV (a0 in nm);
    # NOTE(review): the factor 4 presumably counts spin × valley degeneracy — confirm.
    Ktheta = 8pi/3*sin(theta/2)
    a0 = 0.246
    dos = dos * (Ktheta/a0)^2 / (4pi^2) / (ecenter[2]-ecenter[1]) * 4
    h = plot(ecenter, dos, xlims=(minimum(E), maximum(E)))
    display(h)

    if save != ""
        open(save, "w") do fp
            foreach(ecenter, dos) do e, d
                println(fp, "$e\t$d")
            end
        end
    end

    return (ecenter, dos)
end


# Compute the bands of both valleys (second valley via -Bin) on a (res+1)×(res+1) grid
# spanning one moiré reciprocal-lattice cell (g1, g2), shifted so gamma is the origin.
# Layer potentials come from the screening model A * V = B (outer-layer diagonal 1+2g here).
# Plots bands Dirac-1:Dirac+2 per layer coloured by layer weight (selects pyplot as a side effect).
# save: if non-empty, write "kx\tky\tE\tE'" for bands Dirac-4:Dirac+5.
# Returns (kx, ky, E, Ep, Dirac) with E, Ep indexed [band, i, j].
function calcTXGBandsFullBZ_screened(Nlayer, config, theta, res, shift; title="", Efield=0, Cq=1000, Cg=0.026, save="", Bin=[0,0], w=[0.08,0.1], anisotropy=1)
    g1 = (sqrt(3), 0)
    g2 = (sqrt(3)/2, 1.5)
    gamma = (sqrt(3)/2, -0.5)
    # Flattened grid with j varying fastest, matching the reshapes below.
    kpoints = [g1.*i .+ g2.*j for i=range(0,stop=1,length=res+1) for j=range(0,stop=1,length=res+1)]

    g = Cg/Cq # Quantum capacitance ratio
    @show V0 = 8.85e-12 * Efield / Cg
    B = zeros(Nlayer, 1); B[1] = V0*g; B[end] = -V0*g
    A = diagm(0 => ones(Nlayer)*(1+2g), 1 => ones(Nlayer-1)*(-g), -1 => ones(Nlayer-1)*(-g))
    @show V = (A\B)[:]

    DK = txg_continuum_bands(Nlayer, config, theta, w, shift, V, kpoints; Bin=Bin, anisotropy=anisotropy)
    DKp = txg_continuum_bands(Nlayer, config, theta, w, shift, V, kpoints; Bin=-Bin, anisotropy=anisotropy)
    Nbands = size(DK.E, 1)
    Dirac = Int(Nbands/2)
    @show size(DK.E)
    @show size(DK.L)
    E = reshape(DK.E, :, res+1, res+1) # E[i,j,k] j,k labels momentum, i labels band
    Ep = reshape(DKp.E, :, res+1, res+1) # E[i,j,k] j,k labels momentum, i labels band
    L = reshape(DK.L, :, Nlayer, res+1, res+1) # band index, layer, and momentum (last two indices)

    kall = map(x -> x .- gamma, kpoints)
    kx = reshape(map(x -> x[1], kall), res+1, res+1)
    ky = reshape(map(x -> x[2], kall), res+1, res+1)

    pyplot()
    h = plot(legend=false, layout=Nlayer, camera=(45,5), size=(1200,1600), clims=(0.0,1.0))
    for layer = 1:Nlayer
        for i = -1:2
            plot!(h, kx, ky, E[Dirac+i,:,:], linetype=:surface, xlabel="kx", ylabel="ky", zlabel="E (eV)", fill_z=L[Dirac+i,layer,:,:], clims=(0.0,0.001), subplot=layer, title="Layer $layer")
        end
    end

    if save != ""
        open(save, "w") do fp
            for k = Dirac-4:Dirac+5
                for i = 1:res+1
                    for j = 1:res+1
                        println(fp, "$(kx[i,j])\t$(ky[i,j])\t$(E[k,i,j])\t$(Ep[k,i,j])")
                    end
                    println(fp, "")
                end
                println(fp, "")
            end
        end
    end

    display(h)

    return (kx, ky, E, Ep, Dirac)
end


# Compute the bands of both valleys (second valley via -Bin) on a square
# (res+1)×(res+1) grid of side maxK (units of Ktheta) centred on gamma.
# Eigenvectors are not stored (nostate=true). Layer potentials come from the screening
# model A * V = B (outer-layer diagonal 1+2g here).
# Plots bands Dirac-1:Dirac+2 as surfaces (selects pyplot as a side effect).
# save: if non-empty, write "kx\tky\tE\tE'" for bands Dirac-4:Dirac+5.
# Returns (kx, ky, E, Ep, Dirac) with kx, ky relative to gamma and E, Ep indexed [band, i, j].
function calcTXGBandsSquare_screened(Nlayer, maxK, config, theta, res, shift; title="", Efield=0, Cq=1000, Cg=0.026, save="", Bin=[0,0], w=[0.08,0.1], anisotropy=1)
    g1 = (1, 0)
    g2 = (0, 1)
    gamma = (sqrt(3)/2, -0.5)
    # Flattened grid with j varying fastest, matching the reshapes below.
    kpoints = [g1.*i .+ g2.*j .+ gamma for i=range(-0.5maxK,stop=0.5maxK,length=res+1) for j=range(-0.5maxK,stop=0.5maxK,length=res+1)]

    g = Cg/Cq # Quantum capacitance ratio
    @show V0 = 8.85e-12 * Efield / Cg
    B = zeros(Nlayer, 1); B[1] = V0*g; B[end] = -V0*g
    A = diagm(0 => ones(Nlayer)*(1+2g), 1 => ones(Nlayer-1)*(-g), -1 => ones(Nlayer-1)*(-g))
    @show V = (A\B)[:]

    DK = txg_continuum_bands(Nlayer, config, theta, w, shift, V, kpoints; Bin=Bin, anisotropy=anisotropy, nostate=true)
    DKp = txg_continuum_bands(Nlayer, config, theta, w, shift, V, kpoints; Bin=-Bin, anisotropy=anisotropy, nostate=true)
    Nbands = size(DK.E, 1)
    Dirac = Int(Nbands/2)
    @show size(DK.E)
    E = reshape(DK.E, :, res+1, res+1) # E[i,j,k] j,k labels momentum, i labels band
    Ep = reshape(DKp.E, :, res+1, res+1) # E[i,j,k] j,k labels momentum, i labels band

    kall = map(x -> x .- gamma, kpoints)
    kx = reshape(map(x -> x[1], kall), res+1, res+1)
    ky = reshape(map(x -> x[2], kall), res+1, res+1)

    pyplot()
    h = plot(legend=false, camera=(45,5), size=(1200,1600))
    for i = -1:2
        plot!(h, kx, ky, E[Dirac+i,:,:], linetype=:surface, xlabel="kx", ylabel="ky", zlabel="E (eV)")
    end

    if save != ""
        open(save, "w") do fp
            for k = Dirac-4:Dirac+5
                for i = 1:res+1
                    for j = 1:res+1
                        println(fp, "$(kx[i,j])\t$(ky[i,j])\t$(E[k,i,j])\t$(Ep[k,i,j])")
                    end
                    println(fp, "")
                end
                println(fp, "")
            end
        end
    end

    display(h)

    return (kx, ky, E, Ep, Dirac)
end
